{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 03_TransE_l1_edge_similarity_based_on_link_recommendation_results\n",
    "#\n",
    "# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on February 6, 2023\n",
    "# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 14, 2023\n",
    "#\n",
    "# This script shows how to analyze the link type recommendation similarity of TransE_l1.\n",
    "#\n",
    "# Required packages:\n",
    "#          numpy\n",
    "#          csv\n",
    "#          torch\n",
    "#\n",
    "# Required files:\n",
    "#          ../../data/drkg/drkg.tsv\n",
    "#          ../../data/drkg/entities.tsv\n",
    "#          ../../data/drkg/relations.tsv\n",
    "#          ../01-model/ckpts/TransE_l1_All_DRKG_0/All_DRKG_TransE_l1_entity.npy\n",
    "#          ../01-model/ckpts/TransE_l1_All_DRKG_0/All_DRKG_TransE_l1_relation.npy\n",
    "#\n",
    "# Source tutorial: https://github.com/gnn4dr/DRKG/blob/master/embedding_analysis/Edge_similarity_based_on_link_recommendation_results.ipynb"
   ]
  },
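  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DRKG Relation Similarity Analysis based on link recommendations\n",
    "\n",
    "This notebook analyzes the similarity of the relation types in DRKG based on their link recommendation results. Specifically, for a given seed node, we predict the K highest-scoring neighbors under one relation type, and then repeat this prediction for every relation type. Relation types whose predicted neighbors overlap significantly are considered more similar."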
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import the required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import csv\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "def seed_torch(seed=42):\n",
    "    seed = int(seed)\n",
    "    random.seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.enabled = True\n",
    "\n",
    "seed_torch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./result/overlapping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the function used to score edges. It should be consistent with the function used to learn the embeddings.\n",
    "\n",
    "The official DGL-KE implementation of the TransE_l1 score function is at:\n",
    "\n",
    "- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/ke_model.py lines 886 - 893\n",
    "\n",
    "- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/pytorch/score_fun.py lines 54 - 59\n",
    "\n",
    "The OpenKE implementation of the TransE_l1 score function is at:\n",
    "\n",
    "- https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/TransE.py\n",
    "\n",
    "The two implementations are essentially the same."
   ]
  },
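  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Margin of the TransE_l1 score; it should match the value used to learn the embeddings.\n",
    "gamma = 18.0\n",
    "\n",
    "def transE_l1(head, rel, tail):\n",
    "    # Score = gamma - ||head + rel - tail||_1; a higher score means a more plausible edge.\n",
    "    score = head + rel - tail\n",
    "    return gamma - torch.norm(score, p=1, dim=-1)"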
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Mapping files\n",
    "\n",
    "Load the mapping files, along with the entity and relation embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "97238\n",
      "107\n"
     ]
    }
   ],
   "source": [
    "entity2id = {}\n",
    "with open(\"../../data/drkg/entities.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id','entity'])\n",
    "    for row_val in reader:\n",
    "        # Map each entity name to its integer ID.\n",
    "        entity2id[row_val['entity']] = int(row_val['id'])\n",
    "\n",
    "print(len(entity2id))\n",
    "\n",
    "rel2id = {}\n",
    "with open(\"../../data/drkg/relations.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id','relation'])\n",
    "    for row_val in reader:\n",
    "        # Map each relation name to its integer ID.\n",
    "        rel2id[row_val['relation']] = int(row_val['id'])\n",
    "\n",
    "print(len(rel2id))\n",
    "\n",
    "node_emb = np.load('../01-model/ckpts/TransE_l1_All_DRKG_0/All_DRKG_TransE_l1_entity.npy')\n",
    "rel_emb = np.load('../01-model/ckpts/TransE_l1_All_DRKG_0/All_DRKG_TransE_l1_relation.npy')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading triplets\n",
    "\n",
    "Load the triplets and map their head entities to IDs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "head_ids = []\n",
    "with open(\"../../data/drkg/drkg.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['head', 'rel', 'tail'])\n",
    "    for row_val in reader:\n",
    "        head = row_val['head']\n",
    "        head_id = entity2id[head]\n",
    "        head_ids.append(head_id)\n",
    "\n",
    "head_ids = np.array(head_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Link prediction\n",
    "\n",
    "Specify the number of seed nodes used for link prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "L = 100\n",
    "device = torch.device('cpu')\n",
    "with torch.no_grad():\n",
    "    node_emb = torch.tensor(node_emb).to(device)\n",
    "    rel_emb = torch.tensor(rel_emb).to(device)\n",
    "    head_ids = torch.tensor(head_ids).to(device)\n",
    "\n",
    "    # Select L random seed heads from the heads of all triplets.\n",
    "    perm = torch.randperm(head_ids.shape[0])\n",
    "    seeds = head_ids[perm[:L]]\n",
    "    seed_heads = node_emb[seeds]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For every relation type, predict the score between each selected seed node and all nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([97238, 400])\n",
      "100 torch.Size([97238])\n"
     ]
    }
   ],
   "source": [
    "flag = True\n",
    "scores = {}\n",
    "for rel in rel2id.keys():\n",
    "    rel_id = rel2id[rel]\n",
    "    rel_embedding = ((rel_emb[rel_id]).repeat(node_emb.shape[0], 1))\n",
    "    \n",
    "    scores[rel] =[transE_l1((seed_heads[i].repeat(node_emb.shape[0], 1)),\n",
    "                            rel_embedding, node_emb) for i in range(seed_heads.shape[0])]\n",
    "    if flag:\n",
    "        print(rel_embedding.shape)\n",
    "        print(len(scores[rel]), scores[rel][0].shape)\n",
    "        flag = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Top K link prediction\n",
    "\n",
    "Specify the number of top-scoring neighbors used to evaluate the overlap of the link predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100 10\n"
     ]
    }
   ],
   "source": [
    "flag = True\n",
    "K = 10\n",
    "top_neighbors = {}\n",
    "for rel in scores.keys():\n",
    "    top_neighbors[rel] = [torch.argsort(score, descending=True)[:K] for score in scores[rel]]\n",
    "    if flag:\n",
    "        print(len(top_neighbors[rel]), len(top_neighbors[rel][0]))\n",
    "        flag = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overlap among predicted neighbors\n",
    "\n",
    "Compute the overlap of the predicted neighbor nodes for each pair of relation types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "overlap_of_predicted_neighbors = []\n",
    "keys = list(scores.keys())\n",
    "\n",
    "for i in range(len(keys)):\n",
    "    for j in range(i + 1, len(keys)):\n",
    "        rel_1 = keys[i]\n",
    "        rel_2 = keys[j]\n",
    "        neighbors_seed_heads_1 = top_neighbors[rel_1]\n",
    "        neighbors_seed_heads_2 = top_neighbors[rel_2]\n",
    "        jaccard = 0\n",
    "        for k in range(len(neighbors_seed_heads_1)):\n",
    "            neighbors_1 = set(neighbors_seed_heads_1[k].cpu().numpy())\n",
    "            neighbors_2 = set(neighbors_seed_heads_2[k].cpu().numpy())\n",
    "            # Jaccard index: |intersection| / |union| of the two top-K neighbor sets.\n",
    "            jaccard += len(neighbors_1 & neighbors_2) / len(neighbors_1 | neighbors_2)\n",
    "        # Average the Jaccard index over all seed heads.\n",
    "        jaccard = jaccard / len(neighbors_seed_heads_1)\n",
    "        overlap_of_predicted_neighbors.append([rel_1, rel_2, jaccard])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Store the sorted overlap results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort in descending order of overlap.\n",
    "results = (sorted(overlap_of_predicted_neighbors, key=lambda x: float(x[2])))[::-1]\n",
    "results_store = [\"{}\\t{}\\t{}\\n\".format(result[0], result[1], result[2]) for result in results]\n",
    "results_store = [\"relation1\\trelation2\\tpercentage of overlapping predicted edges\\n\"] + results_store\n",
    "file = \"./result/overlapping/TransE_l1_percentage_of_overlapping_predicted_edges_per_edge_pair\" + str(K) + \".tsv\"\n",
    "\n",
    "with open(file, 'w+') as f:\n",
    "    f.writelines(results_store)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
